# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common routines for working with struct transpiler in bazel files."""

load(
    "//transpiler:fhe_common.bzl",
    "FHE_ENCRYPTION_SCHEMES",
    "executable_attr",
)
load(
    "//transpiler:parsers.bzl",
    "XlsCcOutputInfo",
)

# Label of the struct transpiler binary. It is passed to executable_attr for the
# rule's `struct_header_generator` attribute; that executable is what both
# header-generation actions below run.
_STRUCT_HEADER_GENERATOR = "//transpiler/struct_transpiler:struct_transpiler"

def _generate_encryption_specific_transpiled_structs_header_path(ctx, library_name, stem, metadata, generic_struct_header, encryption, unwrap = []):
    """Transpile C++ structs/classes into scheme-specific FHE base classes.

    Declares a `<basename>.types.h` output and registers an action that runs
    the struct header generator over the XLScc metadata and the generic
    struct header to produce it.

    Args:
      ctx: The rule context; must expose `executable.struct_header_generator`.
      library_name: Library name; used as the header basename whenever it
        differs from `stem`.
      stem: Source stem; when equal to `library_name`, the header basename is
        `<stem>_<encryption>`.
      metadata: File with XLScc metadata, forwarded as `-metadata_path`.
      generic_struct_header: Generic header File, forwarded as
        `-generic_header_path`.
      encryption: Encryption scheme name, forwarded as `-backend_type`.
      unwrap: Struct names to unwrap. When non-empty they are forwarded
        comma-joined via `-unwrap`; when empty the flag is omitted.

    Returns:
      The declared scheme-specific header File.
    """
    if stem == library_name:
        basename = stem + "_" + encryption
    else:
        basename = library_name
    output_header = ctx.actions.declare_file("%s.types.h" % basename)

    # File arguments expand to their paths on the command line.
    cmd = ctx.actions.args()
    cmd.add("-metadata_path", metadata)
    cmd.add("-output_path", output_header)
    cmd.add("-generic_header_path", generic_struct_header)
    cmd.add("-backend_type", encryption)

    # add_joined omits both the flag and the value for an empty list
    # (omit_if_empty defaults to True).
    cmd.add_joined("-unwrap", unwrap, join_with = ",")

    ctx.actions.run(
        executable = ctx.executable.struct_header_generator,
        arguments = [cmd],
        inputs = [metadata, generic_struct_header],
        outputs = [output_header],
    )
    return output_header

# Provider returned by xls_cc_transpiled_structs. In this file both fields are
# set to single-element depsets of Files by _xls_cc_transpiled_structs_impl.
# NOTE(review): the "May be None" wording on generic_struct_header is never
# exercised here; presumably it covers producers elsewhere. Confirm before
# relying on it.
TranspiledStructsOutputInfo = provider(
    """Provide file attribute representing scheme-specific transpiled-structs header.""",
    fields = {
        # depset of the generic `<library_name>.generic.types.h` File (as set in this file).
        "generic_struct_header": "Templates for generic encodings of C++ structs in the source headers. May be None",
        # depset of the scheme-specific `*.types.h` File (as set in this file).
        "encryption_specific_transpiled_structs_header": "Scheme-specific encodings of XLScc structs",
    },
)

def _generate_generic_struct_header(ctx, library_name, metadata, hdr_files, unwrap = []):
    """Transpile XLS C++ structs/classes into generic FHE base classes.

    Declares `<library_name>.generic.types.h` and registers an action that runs
    the struct header generator to produce it.

    Args:
      ctx: The rule context; must expose `executable.struct_header_generator`.
      library_name: Basename used for the declared output header.
      metadata: File with XLScc metadata, forwarded as `-metadata_path`. It is
        the only declared action input.
      hdr_files: Original C++ header Files. Only their paths are forwarded,
        comma-joined, via `-original_headers`.
      unwrap: Struct names to unwrap. When non-empty they are forwarded
        comma-joined via `-unwrap`; when empty the flag is omitted.

    Returns:
      The declared generic header File.
    """
    output_header = ctx.actions.declare_file("%s.generic.types.h" % library_name)

    # The headers are joined eagerly so that an empty list still yields
    # `-original_headers ""`, exactly like a plain string join.
    header_paths = ",".join([hdr.path for hdr in hdr_files])

    cmd = ctx.actions.args()
    cmd.add("-metadata_path", metadata)
    cmd.add("-original_headers", header_paths)
    cmd.add("-output_path", output_header)

    # add_joined omits both the flag and the value for an empty list
    # (omit_if_empty defaults to True).
    cmd.add_joined("-unwrap", unwrap, join_with = ",")

    # NOTE(review): hdr_files are not declared as inputs. This is only hermetic
    # if the generator uses the paths (e.g. for #include lines) without reading
    # the files. Confirm against the struct_transpiler source.
    ctx.actions.run(
        executable = ctx.executable.struct_header_generator,
        arguments = [cmd],
        inputs = [metadata],
        outputs = [output_header],
    )
    return output_header

def _xls_cc_transpiled_structs_impl(ctx):
    """Implementation of the xls_cc_transpiled_structs rule.

    Generates the generic struct header from the `src` target's XLScc metadata
    and headers. It then generates the header specific to `ctx.attr.encryption`
    on top of it, and exposes both through DefaultInfo and
    TranspiledStructsOutputInfo.

    Args:
      ctx: The rule context.

    Returns:
      A list of [DefaultInfo, TranspiledStructsOutputInfo]. The provider fields
      are single-element depsets.
    """
    xls_info = ctx.attr.src[XlsCcOutputInfo]

    # Only the first metadata file is consumed. Fail with a clear message
    # instead of an opaque index error when there is none.
    metadata_files = xls_info.metadata.to_list()
    if not metadata_files:
        fail("%s provides no XLScc metadata file" % ctx.attr.src.label)
    metadata = metadata_files[0]

    library_name = xls_info.library_name

    generic_struct_h = _generate_generic_struct_header(
        ctx,
        library_name,
        metadata,
        xls_info.hdr_files,
        ctx.attr.unwrap,
    )

    # The generic header File is passed directly. There is no need to wrap it
    # in a depset and flatten it back out.
    specific_struct_h = _generate_encryption_specific_transpiled_structs_header_path(
        ctx,
        library_name,
        xls_info.stem,
        metadata,
        generic_struct_h,
        ctx.attr.encryption,
        ctx.attr.unwrap,
    )

    return [
        DefaultInfo(files = depset([generic_struct_h, specific_struct_h])),
        TranspiledStructsOutputInfo(
            generic_struct_header = depset([generic_struct_h]),
            encryption_specific_transpiled_structs_header = depset([specific_struct_h]),
        ),
    ]

# Rule that runs the struct transpiler over a target providing XlsCcOutputInfo.
# It emits two headers: the generic `<library_name>.generic.types.h` and a
# scheme-specific `*.types.h`, both returned in DefaultInfo and
# TranspiledStructsOutputInfo.
xls_cc_transpiled_structs = rule(
    doc = """
      This rule transpiles C++ structs/classes to generic FHE encodings and produces transpiled C++ 
      code for specified encryption scheme that can be included in other libraries and binaries.
      """,
    implementation = _xls_cc_transpiled_structs_impl,
    attrs = {
        # Must carry XlsCcOutputInfo. Its metadata, hdr_files, library_name
        # and stem drive both generated headers.
        "src": attr.label(
            providers = [XlsCcOutputInfo],
            doc = "A target generated by rule cc_to_xls_ir",
            mandatory = True,
        ),
        # Allowed values are the keys of FHE_ENCRYPTION_SCHEMES (loaded from
        # fhe_common.bzl). The value is forwarded as `-backend_type` and names
        # the specific header.
        # NOTE(review): the choice list in the doc string is hand-written and
        # may drift from FHE_ENCRYPTION_SCHEMES. Confirm they agree.
        "encryption": attr.string(
            doc = """
            FHE encryption scheme used by the resulting program. Choices are
            {tfhe, openfhe, cleartext}. 'cleartext' means the program runs in cleartext,
            skipping encryption; this has zero security, but is useful for debugging.
            """,
            values = FHE_ENCRYPTION_SCHEMES.keys(),
            default = "tfhe",
        ),
        # Forwarded comma-joined via `-unwrap` to both header-generation
        # actions. An empty list omits the flag.
        "unwrap": attr.string_list(
            doc = """
            A list of struct names to unwrap.  To unwrap a struct is defined
            only for structs that contain a single field.  When unwrapping a
            struct, its type is replaced by the type of its field.
            """,
        ),
        # Tool run by both header-generation actions. It is built by
        # executable_attr from _STRUCT_HEADER_GENERATOR, presumably as the
        # default label (executable_attr is defined in fhe_common.bzl).
        "struct_header_generator": executable_attr(_STRUCT_HEADER_GENERATOR),
    },
)
